# %% [markdown]
# # Simple Microscope Simulation - Biological Cells
#
# This notebook demonstrates simulating microscope imaging of multiple biological cells using the `simple_microscope` function.
#
# ## Overview
#
# We create a 2D sample with ~100 biological cells:
# - Uniformly distributed across the field of view (grid + jitter)
# - Random cell and nucleus sizes
# - Each cell has a unique refractive index (clustered values)
# - Size-dependent transparency: bigger cells are more transparent
# - Projected to a 2D transmission function for imaging using vmap

# %% [markdown]
# ## Imports

# %%
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5,6,7"  # Skip GPU 0

import janssen as jns
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import cmocean.cm as cmo
from matplotlib_scalebar.scalebar import ScaleBar
from matplotlib.patches import Rectangle
import matplotlib as mpl

print(f"JAX devices: {jax.devices()}")
print(f"Number of GPUs available: {jax.device_count()}")

# %%
mpl.rcParams["animation.embed_limit"] = 500

# %%
jns.__version__

# %%
# %load_ext autoreload
# %autoreload 2

# %% [markdown]
# ## Define Simulation Parameters
#
# We create a sample with 4096x4096 pixels at 0.25 µm pixel size to contain many cells.

# %%
# --- Imaging grid and illumination ---------------------------------------
wavelength = 633e-9  # 633 nm (HeNe laser)
pixel_size = 0.25e-6  # 0.25 microns per sample pixel
num_pixels = 4096  # pixels per side; large grid for many cells

# --- Cell population -----------------------------------------------------
num_cells = 100
min_cell_radius_um = 8  # smallest cell radius, microns
max_cell_radius_um = 100  # largest cell radius, microns
nucleus_ratio = 0.35  # nucleus radius / cell radius (typically 30-40%)

# The same radii expressed in pixels of the sample grid
min_cell_radius_pixels, max_cell_radius_pixels = (
    radius_um * 1e-6 / pixel_size
    for radius_um in (min_cell_radius_um, max_cell_radius_um)
)

# Refractive index of the surrounding medium (water); purely real
n_medium = 1.337 + 0.0j

# Seed NumPy's global generator so the random cells are reproducible
np.random.seed(42)

# One summary line per parameter group
summary_lines = (
    f"Pixel size: {pixel_size * 1e6:.2f} microns",
    f"Grid size: {num_pixels} x {num_pixels} pixels",
    f"Field of view: {pixel_size * num_pixels * 1e6:.0f} µm = {pixel_size * num_pixels * 1e3:.2f} mm",
    f"Wavelength: {wavelength * 1e9:.0f} nm",
    f"Number of cells: {num_cells}",
    f"Cell radius range: {min_cell_radius_um}-{max_cell_radius_um} µm ({min_cell_radius_pixels:.0f}-{max_cell_radius_pixels:.0f} pixels)",
)
print("\n".join(summary_lines))
# %% [markdown]
# ## 1. Generate Random Cell Parameters
#
# Create random positions, radii, and refractive indices for all cells using a uniform grid distribution with jitter.

# %%
# NOTE: every np.random call below consumes the globally seeded stream, so
# the calls must stay in this order to reproduce the same sample.

# --- Cell centres: regular grid + jitter ---------------------------------
grid_size = int(np.ceil(np.sqrt(num_cells)))  # e.g., 10x10 grid for 100 cells

# Grid nodes sit half a spacing inside the border; cells can extend to the
# edges and be cut off.
grid_spacing_y = num_pixels / grid_size
grid_spacing_x = num_pixels / grid_size

grid_y, grid_x = np.meshgrid(
    np.linspace(
        grid_spacing_y / 2, num_pixels - grid_spacing_y / 2, grid_size
    ),
    np.linspace(
        grid_spacing_x / 2, num_pixels - grid_spacing_x / 2, grid_size
    ),
)
grid_centers = np.stack([grid_y.ravel(), grid_x.ravel()], axis=1)

# Pick num_cells distinct grid nodes (grid_size**2 >= num_cells by construction)
selected_indices = np.random.choice(
    len(grid_centers), num_cells, replace=False
)
cell_centers = grid_centers[selected_indices]

# Jitter of up to 25% of the grid spacing keeps the layout uniform-looking
jitter_amount = 0.25 * min(grid_spacing_y, grid_spacing_x)
cell_centers_y = cell_centers[:, 0] + np.random.uniform(
    -jitter_amount, jitter_amount, num_cells
)
cell_centers_x = cell_centers[:, 1] + np.random.uniform(
    -jitter_amount, jitter_amount, num_cells
)

# --- Radii ---------------------------------------------------------------
cell_radii_pixels = np.random.uniform(
    min_cell_radius_pixels, max_cell_radius_pixels, num_cells
)
nucleus_radii_pixels = cell_radii_pixels * nucleus_ratio

# Radius mapped to [0, 1]. When min and max radius coincide the original
# expression divides by zero (NaN refractive indices for every cell), so
# fall back to the mid-range value in that case.
radius_span = max_cell_radius_pixels - min_cell_radius_pixels
if radius_span > 0:
    radius_normalized = (
        cell_radii_pixels - min_cell_radius_pixels
    ) / radius_span
else:
    radius_normalized = np.full(num_cells, 0.5)

# --- Refractive indices --------------------------------------------------
# Bigger cells are more transparent (lower contrast), smaller cells less so.
# Cytoplasm, real part: base value runs from 1.355 (largest cell) to 1.370
# (smallest cell); with the +/-0.002 scatter the range is about 1.353-1.372.
base_n_cytoplasm_real = 1.355 + 0.015 * (1 - radius_normalized)
cytoplasm_n_real = base_n_cytoplasm_real + np.random.uniform(
    -0.002, 0.002, num_cells
)

# Cytoplasm, imaginary part (absorption); clipped so it is never negative
base_n_cytoplasm_imag = 0.0005 + 0.002 * (1 - radius_normalized)
cytoplasm_n_imag = base_n_cytoplasm_imag + np.random.uniform(
    -0.0002, 0.0002, num_cells
)
cytoplasm_n_imag = np.maximum(cytoplasm_n_imag, 0)

cytoplasm_n = cytoplasm_n_real + 1j * cytoplasm_n_imag

# Nucleus: real part about 0.02 above the cytoplasm, absorption about 1.5x
nucleus_n_real = (
    cytoplasm_n_real + 0.02 + np.random.uniform(-0.005, 0.005, num_cells)
)
nucleus_n_imag = cytoplasm_n_imag * 1.5 + np.random.uniform(
    -0.0003, 0.0003, num_cells
)
nucleus_n_imag = np.maximum(nucleus_n_imag, 0)

nucleus_n = nucleus_n_real + 1j * nucleus_n_imag

# --- Summary -------------------------------------------------------------
print(f"Generated {num_cells} cells with uniform distribution")
print(
    f"Grid: {grid_size}x{grid_size} = {grid_size**2} positions, selected {num_cells}"
)
print(f"Grid spacing: {grid_spacing_y:.0f} x {grid_spacing_x:.0f} pixels")
print(
    f"Center Y range: {cell_centers_y.min():.0f} to {cell_centers_y.max():.0f} pixels"
)
print(
    f"Center X range: {cell_centers_x.min():.0f} to {cell_centers_x.max():.0f} pixels"
)
print(
    f"Cell radius range: {cell_radii_pixels.min():.0f} to {cell_radii_pixels.max():.0f} pixels"
)
print(
    f"Nucleus radius range: {nucleus_radii_pixels.min():.0f} to {nucleus_radii_pixels.max():.0f} pixels"
)
print(
    f"Cytoplasm n (real) range: {cytoplasm_n_real.min():.4f} to {cytoplasm_n_real.max():.4f}"
)
print(
    f"Nucleus n (real) range: {nucleus_n_real.min():.4f} to {nucleus_n_real.max():.4f}"
)
# %% [markdown]
# ## 2. Create 2D Sample with Cell Projections using vmap
#
# For each cell, we compute the 2D projection (optical path length through a spherical cell with nucleus).
# The path length at position (x,y) for a sphere of radius R is:
# $L(x,y) = 2\sqrt{R^2 - x^2 - y^2}$ for $x^2 + y^2 < R^2$

# %%
# Create 2D sample with projected cells on CPU (GPU OOM for large grids)
# Force CPU for this cell to avoid GPU memory issues
with jax.default_device(jax.devices('cpu')[0]):
    # Pixel-index coordinate grids, each of shape (num_pixels, num_pixels)
    y_coords = jnp.arange(num_pixels)
    x_coords = jnp.arange(num_pixels)
    yy, xx = jnp.meshgrid(y_coords, x_coords, indexing="ij")

    # Wave number
    k = 2 * jnp.pi / wavelength

    # Convert cell parameters to JAX arrays
    centers_y = jnp.array(cell_centers_y)
    centers_x = jnp.array(cell_centers_x)
    cell_radii = jnp.array(cell_radii_pixels)
    nucleus_radii = jnp.array(nucleus_radii_pixels)
    cytoplasm_n_arr = jnp.array(cytoplasm_n)
    nucleus_n_arr = jnp.array(nucleus_n)

    def compute_cell_transmission(cy, cx, cell_radius, nuc_radius, n_cyto, n_nuc):
        """Compute the complex transmission map of one cell with a nucleus.

        The cell and its nucleus are modelled as concentric spheres centred
        on pixel (cy, cx). The chord through a sphere of radius R at lateral
        distance d from the centre is 2 * sqrt(max(R**2 - d**2, 0)).

        Args:
            cy, cx: Cell centre in pixel coordinates (row, column).
            cell_radius: Cell radius in pixels.
            nuc_radius: Nucleus radius in pixels.
            n_cyto: Complex refractive index of the cytoplasm.
            n_nuc: Complex refractive index of the nucleus.

        Returns:
            Complex array with the shape of ``yy``/``xx``. Pixels the cell
            does not cover have zero path length and therefore value 1.

        Reads ``yy``, ``xx``, ``k``, ``pixel_size`` and ``n_medium`` from
        the enclosing notebook scope.
        """
        # Squared distance from cell center (in pixels)
        dist_sq = (yy - cy) ** 2 + (xx - cx) ** 2

        # Chord length through the whole cell (in pixels)
        cell_path_pixels = 2 * jnp.sqrt(jnp.maximum(cell_radius**2 - dist_sq, 0))

        # Chord length through the nucleus (in pixels)
        nucleus_path_pixels = 2 * jnp.sqrt(jnp.maximum(nuc_radius**2 - dist_sq, 0))

        # Cytoplasm path = total cell path minus nucleus path
        cytoplasm_path_pixels = cell_path_pixels - nucleus_path_pixels

        # Convert to meters
        cytoplasm_path_meters = cytoplasm_path_pixels * pixel_size
        nucleus_path_meters = nucleus_path_pixels * pixel_size

        # Index contrast relative to the medium; the imaginary part of the
        # contrast becomes attenuation once multiplied by 1j below
        delta_n_cyto = n_cyto - n_medium
        delta_n_nuc = n_nuc - n_medium

        # Total transmission = exp(i*k*(delta_n_cyto*L_cyto + delta_n_nuc*L_nuc))
        cell_transmission = jnp.exp(
            1j
            * k
            * (
                delta_n_cyto * cytoplasm_path_meters
                + delta_n_nuc * nucleus_path_meters
            )
        )

        return cell_transmission

    # Process cells in batches (on CPU). Ceiling division so that a final
    # partial batch is still processed when num_cells is not a multiple of
    # batch_size (floor division silently dropped those cells).
    batch_size = 10
    num_batches = -(-num_cells // batch_size)

    print(f"Creating sample using batched vmap on CPU ({num_batches} batches of {batch_size} cells)...")
    # NOTE(review): complex128 is only honoured when JAX x64 mode is enabled;
    # otherwise JAX falls back to complex64 - confirm which is intended.
    sample_transmission = jnp.ones((num_pixels, num_pixels), dtype=jnp.complex128)

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        # Clamp so the progress message reports the true number of cells
        end_idx = min(start_idx + batch_size, num_cells)

        # vmap over cells in this batch
        batch_transmissions = jax.vmap(compute_cell_transmission)(
            centers_y[start_idx:end_idx],
            centers_x[start_idx:end_idx],
            cell_radii[start_idx:end_idx],
            nucleus_radii[start_idx:end_idx],
            cytoplasm_n_arr[start_idx:end_idx],
            nucleus_n_arr[start_idx:end_idx],
        )

        # Overlapping cells combine multiplicatively
        batch_product = jnp.prod(batch_transmissions, axis=0)
        sample_transmission = sample_transmission * batch_product

        print(f" Processed batch {batch_idx+1}/{num_batches} ({end_idx} cells total)")

print(f"Sample created using batched vmap!")
print(
    f"Amplitude range: {jnp.abs(sample_transmission).min():.4f} to {jnp.abs(sample_transmission).max():.4f}"
)
print(
    f"Phase range: {jnp.angle(sample_transmission).min():.4f} to {jnp.angle(sample_transmission).max():.4f} rad"
)

# %%
# Create sample function
cell_sample = jns.types.make_sample_function(
    sample=sample_transmission,
    dx=pixel_size,
)

print(f"Sample shape: {cell_sample.sample.shape}")
print(f"Sample dx: {cell_sample.dx * 1e6:.2f} microns")
# %%
# Visualize the sample: amplitude and phase side by side
amp = jnp.abs(cell_sample.sample)
phase = jnp.angle(cell_sample.sample)

# (data, title, colorbar label, imshow keyword arguments) for each panel
sample_panels = (
    (amp, "Amplitude (Transmission)", "Transmission", dict(cmap=cmo.gray)),
    (
        phase,
        "Phase",
        "Phase (rad)",
        dict(cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi),
    ),
)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for panel_ax, (panel_data, panel_title, cbar_label, imshow_kwargs) in zip(
    axes, sample_panels
):
    panel_image = panel_ax.imshow(panel_data, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.add_artist(
        ScaleBar(cell_sample.dx, "m", length_fraction=0.25, color="black")
    )
    panel_ax.axis("off")
    plt.colorbar(panel_image, ax=panel_ax, label=cbar_label)

plt.suptitle(f"Sample with {num_cells} Biological Cells", fontsize=14)
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 3. Create Illumination Wavefront

# %%
# Square plane-wave probe sampled on the same pixel size as the sample
illumination_size = 512

lightwave = jns.models.plane_wave(
    wavelength=wavelength,
    dx=pixel_size,
    grid_size=(illumination_size,) * 2,
    amplitude=1.0,
)

print(f"Illumination field shape: {lightwave.field.shape}")
print(f"Illumination wavelength: {lightwave.wavelength * 1e9:.0f} nm")
print(f"Illumination dx: {lightwave.dx * 1e6:.2f} microns")
print(f"Illumination FOV: {illumination_size * pixel_size * 1e6:.0f} microns")
# %% [markdown]
# ## 4. Set Microscope Parameters

# %%
# Optical train between the sample and the detector
zoom_factor = 10.0  # 10x magnification
aperture_diameter = 1e-3  # 1 mm aperture
travel_distance = 0.15  # 150 mm to camera
detector_pixel_size = jnp.array(16e-6)  # 16 micron camera pixels

microscope_summary = (
    f"Zoom factor: {zoom_factor}x",
    f"Aperture diameter: {aperture_diameter * 1e3:.1f} mm",
    f"Travel distance: {travel_distance * 1e3:.0f} mm",
    f"Detector pixel size: {detector_pixel_size * 1e6:.1f} µm",
)
print("\n".join(microscope_summary))

# %% [markdown]
# ## 5. Step-by-Step Diffractogram Formation
#
# Let's visualize each step in the formation of a diffractogram:
# 1. **Linear Interaction** - Light interacts with the sample
# 2. **Optical Zoom** - Magnification by the objective lens
# 3. **Circular Aperture** - Limits the numerical aperture
# 4. **Fraunhofer Propagation** - Far-field propagation to the camera
# %%
# Cut a probe-sized window out of the sample centre for the step-by-step
# walkthrough
center_pixel = num_pixels // 2
half_size = illumination_size // 2
center_window = slice(center_pixel - half_size, center_pixel + half_size)
sample_cut = cell_sample.sample[center_window, center_window]

sample_region = jns.types.make_sample_function(
    sample=sample_cut,
    dx=pixel_size,
)

print(f"Sample region shape: {sample_region.sample.shape}")

# %%
# Step 1: Linear Interaction - Light through sample
after_sample = jns.scopes.linear_interaction(
    sample=sample_region,
    light=lightwave,
)

print(f"After sample field shape: {after_sample.field.shape}")
print(f"After sample dx: {after_sample.dx * 1e6:.2f} microns")

# (data, title, pixel size for the scale bar, imshow keyword arguments)
step1_panels = (
    (
        jnp.abs(sample_region.sample),
        "Sample Region",
        sample_region.dx,
        dict(cmap=cmo.gray),
    ),
    (
        jnp.abs(after_sample.field) ** 2,
        "Field Intensity After Sample",
        after_sample.dx,
        dict(cmap=cmo.gray),
    ),
    (
        jnp.angle(after_sample.field),
        "Field Phase After Sample",
        after_sample.dx,
        dict(cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi),
    ),
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for panel_ax, (panel_data, panel_title, panel_dx, imshow_kwargs) in zip(
    axes, step1_panels
):
    panel_image = panel_ax.imshow(panel_data, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.add_artist(
        ScaleBar(panel_dx, "m", length_fraction=0.25, color="black")
    )
    panel_ax.axis("off")
    plt.colorbar(panel_image, ax=panel_ax)

plt.suptitle("Step 1: Linear Interaction", fontsize=14)
plt.tight_layout()
plt.show()
# %%
# Step 2: Optical Zoom - Magnification
zoomed_wave = jns.prop.optical_zoom(after_sample, zoom_factor)

print(f"Before zoom dx: {after_sample.dx * 1e6:.2f} microns")
print(f"After zoom dx: {zoomed_wave.dx * 1e6:.2f} microns")
print(f"Magnification achieved: {zoomed_wave.dx / after_sample.dx:.1f}x")

# (wavefront, panel title) - intensity before and after the zoom
zoom_panels = (
    (after_sample, f"Before Zoom (dx={after_sample.dx*1e6:.2f} µm)"),
    (zoomed_wave, f"After Zoom (dx={zoomed_wave.dx*1e6:.2f} µm)"),
)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for panel_ax, (panel_wave, panel_title) in zip(axes, zoom_panels):
    panel_image = panel_ax.imshow(
        jnp.abs(panel_wave.field) ** 2, cmap=cmo.gray
    )
    panel_ax.set_title(panel_title)
    panel_ax.add_artist(
        ScaleBar(panel_wave.dx, "m", length_fraction=0.25, color="black")
    )
    panel_ax.axis("off")
    plt.colorbar(panel_image, ax=panel_ax)

plt.suptitle("Step 2: Optical Zoom (Magnification)", fontsize=14)
plt.tight_layout()
plt.show()

# %%
# Step 3: Circular Aperture - NA Limit
after_aperture = jns.optics.circular_aperture(
    zoomed_wave,
    diameter=aperture_diameter,
)

print(f"Aperture diameter: {aperture_diameter * 1e3:.1f} mm")

# (wavefront, panel title) - intensity before and after the aperture
aperture_panels = (
    (zoomed_wave, "Before Aperture"),
    (after_aperture, "After Circular Aperture"),
)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for panel_ax, (panel_wave, panel_title) in zip(axes, aperture_panels):
    panel_image = panel_ax.imshow(
        jnp.abs(panel_wave.field) ** 2, cmap=cmo.gray
    )
    panel_ax.set_title(panel_title)
    panel_ax.add_artist(
        ScaleBar(panel_wave.dx, "m", length_fraction=0.25, color="black")
    )
    panel_ax.axis("off")
    plt.colorbar(panel_image, ax=panel_ax)

plt.suptitle("Step 3: Circular Aperture", fontsize=14)
plt.tight_layout()
plt.show()
# %%
# Step 4: Fraunhofer Propagation - To Camera Plane
at_camera = jns.prop.fraunhofer_prop_scaled(
    after_aperture, travel_distance, output_dx=detector_pixel_size
)

print(f"Propagation distance: {travel_distance * 1e3:.0f} mm")
print(f"Camera plane dx: {at_camera.dx * 1e6:.2f} microns")

# (data, title, imshow keyword arguments); all panels share the camera dx
camera_panels = (
    (
        jns.optics.field_intensity(at_camera.field),
        "Intensity at Camera (Linear)",
        dict(cmap=cmo.haline),
    ),
    (
        # log10(1 + I) keeps zero-intensity pixels finite
        jnp.log10(1 + jns.optics.field_intensity(at_camera.field)),
        "Intensity at Camera (Log)",
        dict(cmap=cmo.haline),
    ),
    (
        jnp.angle(at_camera.field),
        "Phase at Camera",
        dict(cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi),
    ),
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for panel_ax, (panel_data, panel_title, imshow_kwargs) in zip(
    axes, camera_panels
):
    panel_image = panel_ax.imshow(panel_data, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.add_artist(
        ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
    )
    panel_ax.axis("off")
    plt.colorbar(panel_image, ax=panel_ax)

plt.suptitle("Step 4: Fraunhofer Propagation to Camera", fontsize=14)
plt.tight_layout()
plt.show()
# %% [markdown]
# ## 6. Compare with simple_diffractogram
#
# Verify that the step-by-step approach matches the `simple_diffractogram` function.

# %%
# Generate single diffractogram using the combined function
diffractogram = jns.scopes.simple_diffractogram(
    sample_cut=sample_region,
    lightwave=lightwave,
    zoom_factor=zoom_factor,
    aperture_diameter=aperture_diameter,
    travel_distance=travel_distance,
    camera_pixel_size=detector_pixel_size,
)

print(f"Diffractogram shape: {diffractogram.image.shape}")
print(f"Diffractogram dx: {diffractogram.dx * 1e6:.2f} µm")

# %%
# Compare manual pipeline with simple_diffractogram
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Manual pipeline result
im0 = axes[0].imshow(
    jns.optics.field_intensity(at_camera.field), cmap=cmo.haline
)
axes[0].set_title("Manual Pipeline (Step by Step)")
scalebar = ScaleBar(at_camera.dx, "m", length_fraction=0.25, color="black")
axes[0].add_artist(scalebar)
axes[0].axis("off")
plt.colorbar(im0, ax=axes[0])

# Combined function result
im1 = axes[1].imshow(diffractogram.image, cmap=cmo.haline)
axes[1].set_title("simple_diffractogram Result")
scalebar = ScaleBar(diffractogram.dx, "m", length_fraction=0.25, color="black")
axes[1].add_artist(scalebar)
axes[1].axis("off")
plt.colorbar(im1, ax=axes[1])

plt.tight_layout()
plt.show()

# Verify they match
print(
    f"Max difference: {jnp.max(jnp.abs(jns.optics.field_intensity(at_camera.field) - diffractogram.image)):.2e}"
)

# %% [markdown]
# ## 7. Full Microscope Simulation - Scanning
#
# Create scan positions and run the full microscope simulation.

# %%
# Create scan positions centered on a region with cells
scan_step = 15e-6  # 15 micron step size (same as Spheres)
scan_pixel = scan_step / cell_sample.dx

# Center of the sample
scope_center = jnp.array(
    [num_pixels // 2, num_pixels // 2]
)  # (x, y) in pixels

num_scan_x = 20  # Full grid (will use multiple GPUs)
num_scan_y = 20

# Offsets of every scan point from the scan centre, in pixels. These are
# deliberately NOT named xx/yy: compute_cell_transmission reads the
# notebook-level xx/yy pixel grids, and rebinding those names here would
# silently change what that function sees if it were called again.
scan_offsets_x, scan_offsets_y = jnp.meshgrid(
    jnp.arange(num_scan_x) * scan_pixel - (num_scan_x - 1) * scan_pixel / 2,
    jnp.arange(num_scan_y) * scan_pixel - (num_scan_y - 1) * scan_pixel / 2,
)
x_positions = scan_offsets_x + scope_center[0]
y_positions = scan_offsets_y + scope_center[1]
# One (x, y) row per scan point
positions = jnp.stack([x_positions.ravel(), y_positions.ravel()], axis=1)

print(f"Scan step: {scan_step * 1e6:.0f} µm ({scan_pixel:.1f} pixels)")
print(f"Number of scan positions: {len(positions)}")
print(f"Scan grid: {num_scan_x} x {num_scan_y}")
print(
    f"Total scan area: {(num_scan_x-1) * scan_step * 1e6:.0f} x {(num_scan_y-1) * scan_step * 1e6:.0f} µm"
)
# %%
# Visualize scan positions on sample
fig, ax = plt.subplots(1, 1, figsize=(10, 10))

im = ax.imshow(jnp.abs(cell_sample.sample), cmap=cmo.gray)
ax.set_title("Sample with Scan Positions")
scalebar = ScaleBar(cell_sample.dx, "m", length_fraction=0.25, color="black")
ax.add_artist(scalebar)

# Add scan positions as dots coloured by their index in the scan order
scatter = ax.scatter(
    positions[:, 0],
    positions[:, 1],
    c=jnp.arange(len(positions)),
    cmap="coolwarm",
    s=10,
    alpha=0.7,
    marker="o",
)
plt.colorbar(scatter, ax=ax, label="Scan position index")
ax.axis("off")

plt.tight_layout()
plt.show()

# %%
# Run simple_microscope with automatic multi-GPU sharding
# Scan positions are built in pixels; convert them to meters for the call
positions_meters = positions * cell_sample.dx

print(f"Running microscope simulation with {len(positions_meters)} scan positions...")
# Report the device count JAX actually sees instead of a hard-coded number
print(f"(Automatic GPU sharding across {jax.device_count()} devices)")

microscope_data = jns.scopes.simple_microscope(
    sample=cell_sample,
    positions=positions_meters,
    lightwave=lightwave,
    zoom_factor=zoom_factor,
    aperture_diameter=aperture_diameter,
    travel_distance=travel_distance,
    camera_pixel_size=detector_pixel_size,
)

print(f"\nMicroscope simulation complete!")
print(f"Microscope data shape: {microscope_data.image_data.shape}")
print(f"Number of diffractograms: {microscope_data.image_data.shape[0]}")
print(f"Diffractogram size: {microscope_data.image_data.shape[1:]}")
print(f"Camera pixel size: {microscope_data.dx * 1e6:.2f} µm")

# %%
# Visualize a subset of diffractograms (9 evenly spaced)
fig, axes = plt.subplots(3, 3, figsize=(12, 12))

indices = jnp.linspace(0, len(positions) - 1, 9).astype(int)

for i, ax in enumerate(axes.flat):
    idx = int(indices[i])
    # Small offset keeps log10 finite on zero-intensity pixels
    im = ax.imshow(
        jnp.log10(microscope_data.image_data[idx] + 1e-10), cmap=cmo.haline
    )
    pos = positions[idx]
    ax.set_title(f"Pos {idx}: ({pos[0]:.0f}, {pos[1]:.0f}) px")
    scalebar = ScaleBar(
        microscope_data.dx, "m", length_fraction=0.25, color="black"
    )
    ax.add_artist(scalebar)
    ax.axis("off")

plt.suptitle(
    "Selected Diffractograms from Cells Sample (Log Scale)", fontsize=14
)
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 8. Gauss-Newton Ptychographic Reconstruction
#
# Now we'll reconstruct the sample from the diffractograms using the Gauss-Newton solver.
#
# The Gauss-Newton method solves the nonlinear least-squares problem:
#
# $$\min_{\text{sample}, \text{probe}} \frac{1}{2} \sum_j \left\| \sqrt{I_j^{\text{meas}}} - \sqrt{I_j^{\text{pred}}(\text{sample}, \text{probe})} \right\|^2$$
#
# using trust-region optimization with Levenberg-Marquardt damping.
] }, { "cell_type": "code", "execution_count": null, "id": "48cbe8f0", "metadata": {}, "outputs": [], "source": [ "# Initialize reconstruction (automatic multi-GPU sharding)\n", "print(\"Initializing reconstruction...\")\n", "print(\"(Automatic GPU sharding across 7 devices)\")\n", "\n", "reconstruction = jns.invert.init_simple_microscope(\n", " experimental_data=microscope_data,\n", " probe_lightwave=lightwave,\n", " zoom_factor=zoom_factor,\n", " aperture_diameter=aperture_diameter,\n", " travel_distance=travel_distance,\n", " camera_pixel_size=detector_pixel_size,\n", ")\n", "\n", "print(f\"\\nInitial sample shape: {reconstruction.sample.sample.shape}\")\n", "print(f\"Initial probe shape: {reconstruction.lightwave.field.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "1a43b6e0", "metadata": {}, "outputs": [], "source": [ "# Check what parameters will be used\n", "from janssen.invert.ptychography import optimal_cg_params\n", "cg_max, cg_tol = optimal_cg_params(microscope_data, reconstruction)\n", "print(f\"Auto-calculated: cg_maxiter={cg_max}, cg_tol={cg_tol:.2e}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7f3b4bfb", "metadata": {}, "outputs": [], "source": [ "# Run Gauss-Newton reconstruction with automatic CG parameter optimization\n", "import time\n", "\n", "num_iterations = 30\n", "\n", "print(f\"Running Gauss-Newton reconstruction ({num_iterations} iterations)...\")\n", "print(\"(Automatic CG parameter optimization, compilation warm-up, and GPU sharding)\")\n", "\n", "start_time = time.time()\n", "\n", "result = jns.invert.simple_microscope_gn(\n", " experimental_data=microscope_data,\n", " reconstruction=reconstruction,\n", " num_iterations=num_iterations,\n", ")\n", "\n", "_ = jax.block_until_ready(result.sample.sample)\n", "elapsed_time = time.time() - start_time\n", "\n", "print(f\"\\nReconstruction complete in {elapsed_time:.1f} seconds!\")\n", "print(f\"Final loss: {result.losses[-1, 1]:.6e}\")" ] }, { 
"cell_type": "code", "execution_count": null, "id": "7734ad24", "metadata": {}, "outputs": [], "source": [
 "# Visualize reconstruction results: one column per field,\n",
 "# amplitude in the top row and phase in the bottom row\n",
 "fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
 "\n",
 "# Reconstructed sample and probe fields returned by the solver\n",
 "recon_sample = result.sample.sample\n",
 "recon_probe = result.lightwave.field\n",
 "\n",
 "# (complex field, pixel size, title template) for each column.\n",
 "# The ground truth is the sample cut to the scan region so that it\n",
 "# matches the reconstruction field of view.\n",
 "panel_columns = [\n",
 "    (sample_region.sample, sample_region.dx, \"Ground Truth - {} (Scan Region)\"),\n",
 "    (recon_sample, result.sample.dx, \"Reconstructed Sample - {}\"),\n",
 "    (recon_probe, result.lightwave.dx, \"Reconstructed Probe - {}\"),\n",
 "]\n",
 "\n",
 "# (label, transform, imshow options) for each row\n",
 "panel_rows = [\n",
 "    (\"Amplitude\", jnp.abs, dict(cmap=cmo.gray)),\n",
 "    (\"Phase\", jnp.angle, dict(cmap=cmo.phase, vmin=-jnp.pi, vmax=jnp.pi)),\n",
 "]\n",
 "\n",
 "for col, (field, pixel_size, title_template) in enumerate(panel_columns):\n",
 "    for row, (quantity, transform, imshow_options) in enumerate(panel_rows):\n",
 "        ax = axes[row, col]\n",
 "        ax.imshow(transform(field), **imshow_options)\n",
 "        ax.set_title(title_template.format(quantity))\n",
 "        ax.add_artist(\n",
 "            ScaleBar(pixel_size, \"m\", length_fraction=0.25, color=\"black\")\n",
 "        )\n",
 "        ax.axis(\"off\")\n",
 "\n",
 "plt.suptitle(\n",
 "    f\"Gauss-Newton Reconstruction ({num_iterations} iterations)\",\n",
 "    fontsize=16,\n",
 ")\n",
 "plt.tight_layout()\n",
 "plt.show()"
] }, { "cell_type": "code", "execution_count": null, "id": "5e781b05", "metadata": {}, "outputs": [], "source": [
 "# Plot loss convergence\n",
 "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n",
 "\n",
 "# The reconstruction cell reads the final loss as result.losses[-1, 1], so the\n",
 "# history is two-dimensional and column 1 holds the loss value. Plotting or\n",
 "# formatting whole rows would draw every column and break the format specs.\n",
 "# NOTE(review): column 0 presumably stores the iteration number - confirm.\n",
 "losses = jnp.asarray(result.losses)\n",
 "loss_history = losses[:, 1] if losses.ndim == 2 else losses\n",
 "\n",
 "iterations = jnp.arange(len(loss_history))\n",
 "ax.semilogy(iterations, loss_history, \"o-\", linewidth=2, markersize=4)\n",
 "ax.set_xlabel(\"Iteration\", fontsize=12)\n",
 "ax.set_ylabel(\"Loss (log scale)\", fontsize=12)\n",
 "ax.set_title(\"Gauss-Newton Convergence\", fontsize=14)\n",
 "ax.grid(True, alpha=0.3)\n",
 "\n",
 "# Convert to Python floats so the numeric format specs below are valid\n",
 "initial_loss = float(loss_history[0])\n",
 "final_loss = float(loss_history[-1])\n",
 "reduction_percent = float(\n",
 "    (loss_history[0] - loss_history[-1]) / loss_history[0] * 100\n",
 ")\n",
 "\n",
 "# Add text with final loss\n",
 "ax.text(\n",
 "    0.98,\n",
 "    0.98,\n",
 "    f\"Final loss: {final_loss:.4e}\",\n",
 "    transform=ax.transAxes,\n",
 "    ha=\"right\",\n",
 "    va=\"top\",\n",
 "    fontsize=11,\n",
 "    bbox=dict(boxstyle=\"round\", facecolor=\"wheat\", alpha=0.5),\n",
 ")\n",
 "\n",
 "plt.tight_layout()\n",
 "plt.show()\n",
 "\n",
 "print(f\"Loss decreased from {initial_loss:.4e} to {final_loss:.4e}\")\n",
 "print(f\"Reduction: {reduction_percent:.1f}%\")"
] }, { "cell_type": "code", "execution_count": null, "id": "6f7a13df", "metadata": {}, "outputs": [], "source": [ "from matplotlib.animation import FuncAnimation\n", "from IPython.display import HTML\n", "\n", "# Create figure with two panels\n", "fig, axes = plt.subplots(1, 2, 
figsize=(14, 6))\n",
 "sample_ax, diffract_ax = axes\n",
 "\n",
 "# Left panel: sample amplitude, with a marker for the current scan position\n",
 "sample_ax.imshow(jnp.abs(cell_sample.sample), cmap=cmo.gray)\n",
 "sample_ax.set_title(\"Sample with Scan Position\")\n",
 "sample_ax.add_artist(\n",
 "    ScaleBar(cell_sample.dx, \"m\", length_fraction=0.25, color=\"black\")\n",
 ")\n",
 "sample_ax.axis(\"off\")\n",
 "\n",
 "# Red dot that is moved from frame to frame; starts with no data\n",
 "position_marker = sample_ax.plot([], [], \"ro\", markersize=10)[0]\n",
 "\n",
 "# Right panel: log-scaled diffractogram, initialized with the first frame\n",
 "first_frame = jnp.log10(microscope_data.image_data[0] + 1e-10)\n",
 "im_diffract = diffract_ax.imshow(first_frame, cmap=cmo.haline)\n",
 "diffract_ax.set_title(\"Diffractogram (Log Scale)\")\n",
 "diffract_ax.add_artist(\n",
 "    ScaleBar(microscope_data.dx, \"m\", length_fraction=0.25, color=\"black\")\n",
 ")\n",
 "diffract_ax.axis(\"off\")\n",
 "plt.colorbar(im_diffract, ax=diffract_ax, label=\"Log₁₀(Intensity)\")\n",
 "\n",
 "plt.tight_layout()\n",
 "\n",
 "\n",
 "def init():\n",
 "    \"\"\"Clear the position marker and return the animated artists.\"\"\"\n",
 "    position_marker.set_data([], [])\n",
 "    return [im_diffract, position_marker]\n",
 "\n",
 "\n",
 "def update(frame):\n",
 "    \"\"\"Draw scan position `frame`: move the marker, swap the diffractogram, retitle.\"\"\"\n",
 "    scan_pos = positions[frame]\n",
 "    position_marker.set_data([scan_pos[0]], [scan_pos[1]])\n",
 "\n",
 "    log_frame = jnp.log10(microscope_data.image_data[frame] + 1e-10)\n",
 "    im_diffract.set_array(log_frame)\n",
 "\n",
 "    sample_ax.set_title(\n",
 "        f\"Sample with Scan Position ({frame+1}/{len(positions)})\"\n",
 "    )\n",
 "\n",
 "    return [im_diffract, position_marker]\n",
 "\n",
 "\n",
 "# One frame per scan position; interval=200 ms corresponds to 5 fps\n",
 "anim = FuncAnimation(\n",
 "    fig,\n",
 "    update,\n",
 "    init_func=init,\n",
 "    frames=len(positions),\n",
 "    interval=200,\n",
 "    blit=False,\n",
 ")\n",
 "\n",
 "# Close the static figure so only the embedded animation is displayed\n",
 "plt.close(fig)\n",
 "HTML(anim.to_jshtml())"
] }, { "cell_type": "code", "execution_count": null, "id": "d9ff0193", "metadata": {}, "outputs": [], "source": [
 "# Write the animation to an MP4 file (uses the ffmpeg writer)\n",
 "anim.save(\"cells_microscope_scan.mp4\", writer=\"ffmpeg\", fps=5, dpi=150)\n",
 "print(\"Video saved as 'cells_microscope_scan.mp4'\")"
] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }